{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ac10f60b-5f0e-41b6-891e-b6d3141b5cff",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([2049990]) torch.Size([214417]) torch.Size([241859])\n"
     ]
    }
   ],
   "source": [
    "import math\n",
    "import os\n",
    "from tempfile import TemporaryDirectory\n",
    "from typing import Tuple\n",
    "\n",
    "import torch\n",
    "from torch import nn, Tensor\n",
    "from torch.nn import TransformerEncoder, TransformerEncoderLayer\n",
    "from torch.utils.data import dataset\n",
    "\n",
    "from torchtext.datasets import WikiText2\n",
    "from torchtext.data.utils import get_tokenizer\n",
    "from torchtext.vocab import build_vocab_from_iterator\n",
    "\n",
    "train_iter = WikiText2(split='train')\n",
    "\n",
    "tokenizer = get_tokenizer('basic_english')\n",
    "\n",
    "vocab = build_vocab_from_iterator(map(tokenizer, train_iter), specials=['<unk>'])\n",
    "# Make lookups of out-of-vocabulary tokens return the '<unk>' id instead of raising\n",
    "vocab.set_default_index(vocab['<unk>'])\n",
    "\n",
    "def data_process(raw_text_iter: dataset.IterableDataset) -> Tensor:\n",
    "    \"\"\"Converts raw text into a flat Tensor.\"\"\"\n",
    "    data = [torch.tensor(vocab(tokenizer(item)), dtype=torch.long) for item in raw_text_iter]\n",
    "    # Blank lines tokenize to empty tensors; drop them before concatenating\n",
    "    return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))\n",
    "\n",
    "# ``train_iter`` was consumed while building the vocab, so we have to create it again\n",
    "train_iter, val_iter, test_iter = WikiText2()\n",
    "\n",
    "train_data = data_process(train_iter)\n",
    "val_data = data_process(val_iter)\n",
    "test_data = data_process(test_iter)\n",
    "\n",
    "print(train_data.shape, val_data.shape, test_data.shape)\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "def batchify(data: Tensor, bsz: int) -> Tensor:\n",
    "    \"\"\"Divides the data into ``bsz`` separate sequences, removing extra elements\n",
    "    that wouldn't cleanly fit.\n",
    "\n",
    "    Arguments:\n",
    "        data: Tensor, shape ``[N]``\n",
    "        bsz: int, batch size\n",
    "\n",
    "    Returns:\n",
    "        Tensor of shape ``[N // bsz, bsz]``\n",
    "    \"\"\"\n",
    "    seq_len = data.size(0) // bsz\n",
    "    # Trim tokens that don't divide evenly, then reshape so that each of the\n",
    "    # ``bsz`` columns holds one contiguous stream of text\n",
    "    data = data[:seq_len * bsz]\n",
    "    data = data.view(bsz, seq_len).t().contiguous()\n",
    "    return data.to(device)\n",
    "\n",
    "batch_size = 20\n",
    "eval_batch_size = 10\n",
    "train_data = batchify(train_data, batch_size)  # shape ``[seq_len, batch_size]``\n",
    "val_data = batchify(val_data, eval_batch_size)\n",
    "test_data = batchify(test_data, eval_batch_size)\n"
   ]
  },
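  {
   "cell_type": "markdown",
   "id": "sanity-checks-note",
   "metadata": {},
   "source": [
    "A couple of quick sanity checks, sketched as toy examples (the inputs below are illustrative, not part of the pipeline): the tokenizer/vocab pair turns a raw string into a list of token ids, and ``batchify`` reshapes a flat token stream so that each *column* of the result is one contiguous subsequence."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sanity-checks-demo",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sanity checks (toy examples, not part of the pipeline).\n",
    "\n",
    "# The basic_english tokenizer lowercases and splits on whitespace/punctuation;\n",
    "# ``vocab`` maps tokens to integer ids, falling back to the '<unk>' id for\n",
    "# out-of-vocabulary tokens.\n",
    "print(tokenizer('The sky is blue.'))\n",
    "print(vocab(tokenizer('The sky is blue.')))\n",
    "\n",
    "# ``batchify`` on a toy stream: 26 tokens with ``bsz=4`` drops the last 2\n",
    "# tokens and yields a ``[6, 4]`` tensor whose columns are contiguous runs.\n",
    "toy = batchify(torch.arange(26), 4)\n",
    "print(toy.shape)    # torch.Size([6, 4])\n",
    "print(toy[:, 0])    # tensor([0, 1, 2, 3, 4, 5]) (on ``device``)\n"
   ]
  },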
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a46213c8-2608-465c-b1aa-57d9d5acdba7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([102499, 20]), torch.Size([21441, 10]), torch.Size([24185, 10]))"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_data.shape, val_data.shape, test_data.shape"
   ]
  },
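  {
   "cell_type": "markdown",
   "id": "shape-check-note",
   "metadata": {},
   "source": [
    "These shapes follow from ``N // bsz``: the 2,049,990 training tokens become ``[102499, 20]`` (the 10 leftover tokens are dropped), and the validation and test streams become ``[21441, 10]`` and ``[24185, 10]``."
   ]
  }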
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch-venv",
   "language": "python",
   "name": "pytorch-venv"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
